#!/usr/bin/python
#-*- coding: UTF-8

import sys
import time

from coldata import *
import mph


def tab(s, tabs=1):
	'''prefix string s with the given number of tab characters (one by default)'''
	indent = '\t' * tabs
	return indent + s


def state2str(codepoints):
	'''transforms state [ "0001", "0002"] into string "state_000001_000002",
	every codepoint is parsed as hex and rendered as 6 uppercase hex digits'''
	padded = []
	for codepoint in codepoints:
		padded.append('%06X' % int(codepoint, base=16))
	return 'state_' + '_'.join(padded)


def str2state(state):
	'''transforms string "state_0001_0002" into state [ "0001", "0002"],
	the first len("state_") characters are dropped without being checked'''
	prefix_len = len('state_')
	body = state[prefix_len:]
	return body.split('_')


def find_children(state, contractions):
	'''return list of (codepoints, weight) pairs taken from contractions
	which are strictly longer than state and start with it'''
	depth = len(state)
	return [
		(points, weight)
		for points, weight in contractions
		if len(points) > depth and points[:depth] == state
	]


def find_parent(state, states):
	'''return name of parent state (intermediate or root): the state from
	states which equals state without its last codepoint.
	returns None when states holds no such state'''
	wanted = state[:-1]
	for name, _ in states:
		candidate = str2state(name)
		if len(candidate) != len(state) - 1:
			continue
		if candidate == wanted:
			return state2str(candidate)
	return None


def final_state(state, contractions):
	'''test if state is final - find_children() yields nothing for it'''
	children = find_children(state, contractions)
	return len(children) == 0


def root_state(state):
	'''test if state is root - made of a single codepoint, so no parent states'''
	depth = len(state)
	return depth == 1


def collect_states(contractions, codepoints):
	'''collect entire list of states: root, intermediate and final

	only contractions made of 2 or more codepoints produce states, each of
	them contributes every non-empty prefix of its codepoints (the complete
	contraction included).

	returns (states, max_level):
	states    - list of (state name, weight), weights are unique negative
	            numbers -1, -2, ... and the list is ordered by ascending weight
	max_level - length of the longest contraction seen, 0 if there was none

	codepoints is not used here, it is kept so that callers don't change'''
	states = set()
	max_level = 0
	for points, _ in contractions:
		if len(points) < 2:
			continue

		for depth in range(1, len(points) + 1):
			states.add(state2str(points[:depth]))

		max_level = max(len(points), max_level)

	# iteration order of a set is arbitrary, number the states in sorted
	# name order so that generated output is the same on every run
	numbered = [ (state, -(i + 1)) for i, state in enumerate(sorted(states)) ]
	return sorted(numbered, key=itemgetter(1)), max_level


def find_closest_parent(state, states, contractions, codepoints):
	'''return closest parent *with assigned weight*

	every state from states which is a prefix of state (state itself
	included) is considered, the longest one for which find_weight() gives
	a truthy result on contractions or on codepoints wins.
	returns None when there is no such state'''
	best = []
	for name, _ in states:
		candidate = str2state(name)
		if len(candidate) > len(state) or state[:len(candidate)] != candidate:
			continue

		# contractions are looked up first, codepoints only if that fails
		has_weight = find_weight(candidate, contractions) or find_weight(candidate, codepoints)
		if not has_weight:
			continue

		if len(candidate) > len(best):
			best = candidate

	return best or None  # return None instead of empty list


def gen_roots_lookup(tag):
	'''produce c code of the "w == 0" branch: lookup of codepoint u
	in <TAG>_ROOTS_* tables, non-zero result is returned negated'''
	tag = (tag + '_roots').upper()  # internal tag

	lines = [
		tab('if (w == 0) { /*  first entry, root states */'),
		tab('uint32_t state = nu_udb_lookup_value(u, %s_G, %s_G_SIZE,' % (tag, tag), 2),
		tab('%s_VALUES_C, %s_VALUES_I);' % (tag, tag), 3),
		'',
		tab('if (state != 0) {', 2),
		tab('return -state; /* VALUES_I store negated (positive) states */', 3),
		tab('}', 2),
		tab('}'),
		'',
	]
	for line in lines:
		print(line)


def expand_children(state, states, contractions, codepoints, level=1):
	'''produce switch for child states

	prints a c switch over the next codepoint `u` with one case per direct
	child of state: a final child returns its weight as a number, any other
	child returns the name of its state. the switch is followed by fallback
	code used when no case matched.

	state        - list of codepoints, it must have at least one child
	states       - list of (state name, weight)
	contractions - list of (codepoints, weight)
	codepoints   - passed to find_weight() as the second lookup table
	level        - indentation (in tabs) of produced code'''

	def fallback(state, states, contractions, codepoints):
		'''produce fallback code for the case when none of child states
		fits the next codepoint: state has to be reverted back up to the
		closest parent with assigned weight'''
		parent = find_closest_parent(state, states, contractions, codepoints)
		assert(parent is not None)
		weight = find_weight(parent, contractions) or find_weight(parent, codepoints)
		assert(weight is not None)
		distance = len(state) - len(parent) + 1
		assert(distance >= 1)

		print(tab('*w = %d;' % (distance), level))
		print(tab('return 0x%06X;' % (weight), level))

	# direct children are prefixes one codepoint longer than state, the set
	# drops duplicates produced by longer contractions sharing a prefix.
	# sorted so that order of emitted cases is stable between runs
	children = sorted(set(tuple(child[:len(state) + 1]) for child, _ in find_children(state, contractions)))
	assert(children)

	print(tab('switch (u) {', level))

	for child in children:
		next_state = list(child)  # lookups below work on lists
		assert(len(next_state) > len(state))

		check_codepoint = next_state[-1]
		if final_state(next_state, contractions):
			weight = '0x%06X' % find_weight(next_state, contractions)
		else:
			weight = state2str(next_state)

		print(tab('case 0x%06X: return %s; ' % (int(check_codepoint, base=16), weight), level))

	print(tab('}', level))
	print('')

	fallback(state, states, contractions, codepoints)


def expand_intermediate_states(states, contractions, codepoints):
	'''produce c code for intermediate states

	prints the "w != 0" branch: *w is copied into local `weight`, reset
	to 0 and then compared against non-final states by a nested
	if / else if / else chain built as binary search over states.
	states must be ordered by weight (asserted)'''

	def split_roots(roots):
		'''kind of binary split of ordered list: returns the element
		at the middle, elements before it and elements after it'''
		pivot = len(roots) // 2
		return roots[pivot], roots[:pivot], roots[pivot + 1:]

	def intermediate_states(states, contractions):
		'''names of states which are not final, order of states is kept'''
		names = []
		for name, _ in states:
			if not final_state(str2state(name), contractions):
				names.append(name)
		return names

	def expand_intermediate_state(state, left, right, states, contractions, codepoints, level=2):
		'''produce code for a single state, then recursively
		for the states on its left and right sides'''
		state_str = state2str(state)
		deeper = level + 1

		print(tab('if (weight == %s) {' % (state_str), level))
		expand_children(state, states, contractions, codepoints, level=deeper)
		print(tab('}', level))

		# left side first, then right side; an empty side emits nothing
		branches = (
			(left, 'else if (weight < %s) {'),
			(right, 'else { /* weight > %s */'),
		)
		for side, opening in branches:
			if not side:
				continue

			print(tab(opening % (state_str,), level))
			middle, before, after = split_roots(side)
			expand_intermediate_state(str2state(middle), before, after, states, contractions, codepoints, level=deeper)
			print(tab('}', level))

	# list has to be ordered for binary search to work
	assert(states == sorted(states, key=itemgetter(1)))

	roots = intermediate_states(states, contractions)
	middle, left, right = split_roots(roots)

	print(tab('if (w != 0) { /* re-entry, intermediate states */'))
	print(tab('int32_t weight = *w;', 2))
	print(tab('*w = 0;', 2))
	print('')

	expand_intermediate_state(str2state(middle), left, right, states, contractions, codepoints)

	print(tab('}'))
	print('')


def gen_header(tag, contractions):
	'''produce info header: c comment with current time (whole seconds),
	tag and number of contractions, followed by an empty line'''
	template = '''/* Automatically generated file (contractions-toc), %d
 *
 * Tag          : %s
 * Contractions : %d
 */'''
	generated_at = time.time()  # %d drops the fractional part
	print(template % (generated_at, tag, len(contractions)))
	print('')


def gen_includes():
	'''produce all includes required to compile generated code,
	every include is followed by an empty line'''
	for line in ('#include <stdint.h>', '', '#include "udb.h"', ''):
		print(line)


def gen_consts(tag, contractions, codepoints):
	'''produce <TAG>_CONTRACTIONS and <TAG>_CODEPOINTS constants holding
	lengths of contractions and codepoints, followed by an empty line'''
	prefix = tag.upper()
	print('const size_t %s_CONTRACTIONS = %d; /* contractions included in switch */' % (prefix, len(contractions)))
	print('const size_t %s_CODEPOINTS = %d; /* complementary codepoints number */' % (prefix, len(codepoints)))
	print('')


def gen_roots_mph(tag, states, compact=False):
	'''produce MPH structures for root codepoints

	every root state (single codepoint) is put into a dict as
	codepoint -> (codepoint, negated weight), the dict is handed to
	mph.create_minimal_perfect_hash() and the result is printed by
	mph.gen_G() and mph.gen_values() under the tag <TAG>_ROOTS.
	compact is passed to mph.gen_values() unchanged'''
	d = {}
	for state, weight in states:
		codepoints = str2state(state)
		if not root_state(codepoints):
			continue

		assert(len(codepoints) == 1)
		root = codepoints[0]
		d[root] = (root, -weight)

	tag = (tag + '_roots').upper()  # internal tag

	(G, V) = mph.create_minimal_perfect_hash(d)
	mph.gen_G(tag, G)
	mph.gen_values(tag, G, V, compact)


def gen_switch(prefix, states, contractions, codepoints):
	'''produce switch entry point

	prints c function <prefix>_weight_switch(u, w, context) made of roots
	lookup (gen_roots_lookup) and intermediate states handling
	(expand_intermediate_states), the function ends with "return 0".

	prefix       - name prefix of the function, also selects roots tables
	states       - list of (state name, weight) ordered by weight
	contractions - list of (codepoints, weight)
	codepoints   - passed down to expand_intermediate_states()'''
	print('/* MPH lookup for root codepoints + binary search on balanced tree')
	print(' * for intermediate states */')
	print('int32_t %s_weight_switch(uint32_t u, int32_t *w, void *context) {' % (prefix))
	print(tab('(void)(context);'))
	print('')

	# use the parameter here: a module-level `tag` exists only
	# when this file is run as a script
	gen_roots_lookup(prefix)
	expand_intermediate_states(states, contractions, codepoints)

	# 0 is an impossible weight because this special case has to be
	# handled before entering the switch, so returning it indicates
	# that the switch didn't find a weight
	print(tab('return 0;'))
	print('}')


def gen_states(states, contractions):
	'''produce defines with state weights, one "#define <state> <weight>"
	per non-final state (root and intermediate), followed by an empty line.
	weight of every printed state is asserted to be negative.
	final states get no define - they are resolved directly into weight
	w/o extra switch'''
	for state_str, weight in states:
		# don't print final states - they will be embedded as numbers
		if final_state(str2state(state_str), contractions):
			continue

		assert(int(weight) < 0)
		print('#define %s %d' % (state_str, weight))
	print('')


def usage():
	'''print command line help to stdout'''
	lines = (
		'usage: ' + sys.argv[0] + ' [CODEPOINTS] [CONTRACTIONS] [TAG] [BMP_ONLY]',
		'',
		'  [CODEPOINTS]   - filename with list of codepoints',
		'  [CONTRACTIONS] - filename with list of contractions from the same collation',
		'  [TAG]          - prefix to weighting switch',
		'  [BMP_ONLY]     - flag to indicate if set is BMP-only, 0 or 1 (false or true)',
	)
	for line in lines:
		print(line)


if __name__ == '__main__':
	# all four positional arguments are required
	if len(sys.argv) < 5:
		usage()
		sys.exit(1)

	codepoints_file, contractions_file, tag, bmp_flag = sys.argv[1:5]
	bmp_only = int(bmp_flag) != 0  # "0" - false, any other integer - true

	# collect_contractions is not defined in this file,
	# presumably it is provided by the coldata star import
	codepoints, contractions = collect_contractions(codepoints_file, contractions_file)
	states, _ = collect_states(contractions, codepoints)

	# generated c source is printed to stdout section by section
	gen_header(tag, contractions)
	gen_includes()
	gen_consts(tag, contractions, codepoints)
	gen_states(states, contractions)
	gen_roots_mph(tag, states, bmp_only)
	gen_switch(tag, states, contractions, codepoints)
